# -*- coding:utf-8 -*-
import baostock as bs
import pandas as pd
from baostock.data.resultset import ResultData

from dao import base_dao
from model.stock_industry_info_model import StockInfo
from system.sys_logger import Logfactory
from utils import uuid_util


def get_all_stock_info():
    """Query baostock for stock industry data and persist it when the call succeeds.

    Any error code other than '0' is logged as an error, and nothing is saved.
    Otherwise the success is logged and the result set is passed to
    save_all_stock_info.
    """
    response = bs.query_stock_industry()

    if response.error_code != '0':
        Logfactory.logger.error(
            "query_stock_industry failed . and respond error_code=%s "
            "error_msg=%s", response.error_code, response.error_msg)
        return

    Logfactory.logger.info(
        "query_stock_industry success . and respond error_code=%s "
        "error_msg=%s", response.error_code, response.error_msg)
    save_all_stock_info(response)


def save_all_stock_info(result: ResultData):
    """Drain a baostock result set and persist its rows to the database and a CSV file.

    Every record is read with result.next() / result.get_row_data() and
    collected into a DataFrame whose columns are result.fields. That frame is
    converted to StockInfo models, which are stored via base_dao.save_all.
    The frame is then written out by save_stock_info_to_file.
    """
    rows = []
    # The result set is cursor-like: advance with next() and read the current row.
    while result.next():
        rows.append(result.get_row_data())
    frame = pd.DataFrame(rows, columns=result.fields)

    base_dao.save_all(get_stock_info_list(frame))
    save_stock_info_to_file(frame)


def get_stock_info_list(stock_info_frame: pd.DataFrame):
    """Convert each row of a stock industry DataFrame into a StockInfo model.

    Returns an empty list when the frame is None or has no rows. Rows for
    which build_stock_info returns None are left out of the result.
    """
    if stock_info_frame is None or len(stock_info_frame) == 0:
        return []

    candidates = (build_stock_info(row) for _, row in stock_info_frame.iterrows())
    return [info for info in candidates if info is not None]


def build_stock_info(row: pd.Series):
    """Map one baostock industry row onto a new StockInfo model.

    :param row: a single row of the query_stock_industry result. It is read
        for the columns code, code_name, industry, industryClassification
        and updateDate.
    :return: a StockInfo with a freshly generated id, or None when row is None.
    """
    if row is None:
        # Logger.warn is a deprecated alias of Logger.warning (assumes
        # Logfactory.logger is a stdlib logging.Logger, as its %-style
        # info/error calls elsewhere in this file suggest).
        Logfactory.logger.warning("baostock stock info is None")
        return None

    # Security code: "sh." or "sz." followed by a 6-digit code, or an index
    # code, e.g. sh.601398 (sh = Shanghai, sz = Shenzhen). Must not be empty.
    stock_code = row['code']
    # Security / stock name
    stock_name = row['code_name']
    # Industry the stock belongs to
    industry = row['industry']
    # Industry classification scheme / category
    industry_classification = row['industryClassification']
    # Date the record was last updated
    update_time = row['updateDate']

    return StockInfo(id=uuid_util.generate_uuid(), stock_code=stock_code, stock_name=stock_name,
                     industry=industry,
                     industry_classification=industry_classification, update_time=update_time)


def save_stock_info_to_file(result: pd.DataFrame, path="../stock_industry.csv"):
    """Write the stock industry DataFrame to a CSV file, without the index column.

    :param result: frame to write; when None, nothing is written.
    :param path: destination file. The default "../stock_industry.csv" keeps
        the previous behavior. A relative path resolves against the process's
        current working directory, not this module's location, so callers can
        pass an explicit path instead.
    """
    if result is None:
        return
    result.to_csv(path, index=False)
